package db

import (
    "context"
    "fmt"
    "sync"
    "time"

    logger "gitee.com/allan577/go-lib-logger"
    "github.com/go-sql-driver/mysql"
    "github.com/jinzhu/gorm"
    _ "github.com/jinzhu/gorm/dialects/postgres"
    "github.com/opentracing/opentracing-go"
)

// ----------------------------------------
//  GORM client
// ----------------------------------------
// GORMClient wraps a *gorm.DB together with the configuration used to open it,
// so the connection can be pinged, closed and re-established later.
type GORMClient struct {
    // name is the instance name passed to CreateGORMClient and returned by Name.
    name        string
    // config is the private copy (config.Copy()) taken by CreateGORMClient.
    config      *ORMConfig
    // mu serializes Connect and Close.
    mu          *sync.Mutex
    // client is the underlying GORM handle: set by Connect after a successful
    // gorm.Open, reset to nil by Close.
    client      *gorm.DB
    // connectedAt is the local time at which Connect last stored a new handle.
    connectedAt time.Time
    // pingedAt is the local time of the last ping attempt (stamped before the
    // ping result is known, so it does not imply success).
    pingedAt    time.Time
}

// gormCallbacksOnce guards the one-time AddGormCallbacks registration performed
// by GORMClient.WithCtx.
var gormCallbacksOnce = new(sync.Once)

// CreateGORMClient creates a GORM client named name from a copy of config and
// connects it immediately.
//
// On failure it returns a nil client together with the error from Connect.
// Any connection pool that Connect opened before failing is closed first, so
// it is not leaked.
func CreateGORMClient(name string, config *ORMConfig) (*GORMClient, error) {
    client := &GORMClient{name: name, config: config.Copy(), mu: &sync.Mutex{}}
    if err := client.Connect(); err != nil {
        // Connect stores the *gorm.DB before pinging it, so a failed ping leaves
        // an open pool on this client. The caller never receives the client, so
        // release the pool here. The Close error is dropped on purpose: the
        // connect error is the one worth reporting.
        _ = client.Close()
        return nil, err
    }
    return client, nil
}

// Name reports the instance name this GORM client was created with.
func (client *GORMClient) Name() string {
    name := client.name
    return name
}

// Client exposes the underlying third-party *gorm.DB handle. It is nil until
// Connect has opened a connection, and nil again after Close.
func (client *GORMClient) Client() *gorm.DB {
    db := client.client
    return db
}

// SetLogger installs log as the logger of the current *gorm.DB handle.
//
// NOTE(review): not nil-safe. Before a successful Connect, or after Close,
// the handle is nil and this call dereferences it.
func (client *GORMClient) SetLogger(log logger.Log) {
    db := client.client
    db.SetLogger(log)
}

// WithCtx returns a *gorm.DB bound to the tracing span carried by ctx.
//
// If ctx is nil or holds no opentracing span, the plain underlying *gorm.DB is
// returned. Otherwise it makes sure AddGormCallbacks has been run once, then
// returns a handle with the span stored under parentSpanGormKey.
//
// NOTE(review): gormCallbacksOnce is package-level. AddGormCallbacks is
// therefore invoked only once per process, with whichever client's *gorm.DB
// reaches this point first. Other clients, and handles created later by
// Reconnect, are never passed to it. Confirm this is intended.
func (client *GORMClient) WithCtx(ctx context.Context) *gorm.DB {
    var parentSpan opentracing.Span
    if ctx != nil {
        parentSpan = opentracing.SpanFromContext(ctx)
    }
    if parentSpan == nil {
        return client.client
    }
    gormCallbacksOnce.Do(func() {
        AddGormCallbacks(client.client)
    })
    return client.client.Set(parentSpanGormKey, parentSpan)
}

// Ping checks that the remote database is reachable.
//
// It returns nil without doing anything when pinging is disabled in the config
// or no *gorm.DB is held. Otherwise it stamps pingedAt with the current local
// time (before the ping, so it records the attempt rather than success) and
// returns the result of pinging the underlying sql.DB.
//
// NOTE(review): Connect calls Ping while holding mu, but direct callers do not,
// so the pingedAt write here is unsynchronized with concurrent Connect/Close.
func (client *GORMClient) Ping() error {
    if !client.config.Ping || client.client == nil {
        return nil
    }
    client.pingedAt = time.Now().Local()
    return client.client.DB().Ping()
}

// Connect opens the connection pool described by the client's config, if it is
// not open yet, and pings it.
//
// If a *gorm.DB is already held, Connect only pings it. Otherwise it:
//   - builds a DSN for the configured driver ("postgres", or MySQL for any
//     other value);
//   - opens it with gorm.Open;
//   - applies the MaxConn / MaxIdleConn pool limits;
//   - records connectedAt;
//   - pings.
//
// A failed final ping still leaves the newly opened *gorm.DB stored on the
// client, and the ping error is returned.
func (client *GORMClient) Connect() error {
    client.mu.Lock()
    defer client.mu.Unlock()
    if client.client != nil {
        return client.Ping()
    }

    var dsn string
    if client.config.Driver == "postgres" {
        dsn = gormPostgresDSN(client.config)
    } else {
        dsn = gormMySQLDSN(client.config)
    }

    db, err := gorm.Open(client.config.Driver, dsn)
    if err != nil {
        return err
    }
    db.DB().SetMaxOpenConns(client.config.MaxConn)
    db.DB().SetMaxIdleConns(client.config.MaxIdleConn)
    // db.DB().SetConnMaxLifetime(time.Second * 300)
    client.client = db
    client.connectedAt = time.Now().Local()
    return client.Ping()
}

// quotePGConnValue renders s as a single-quoted libpq keyword/value connection
// string value, escaping backslashes and single quotes with a backslash.
// Without quoting, an empty value makes the parser take the next "key=value"
// pair as the value, and spaces, quotes or backslashes split or mangle it.
// Working byte-wise is safe for UTF-8: '\\' and '\'' never occur inside a
// multi-byte sequence.
func quotePGConnValue(s string) string {
    buf := make([]byte, 0, len(s)+2)
    buf = append(buf, '\'')
    for i := 0; i < len(s); i++ {
        if s[i] == '\\' || s[i] == '\'' {
            buf = append(buf, '\\')
        }
        buf = append(buf, s[i])
    }
    buf = append(buf, '\'')
    return string(buf)
}

// gormPostgresDSN builds a keyword/value DSN for the postgres driver, with
// every value quoted and SSL disabled. The port is emitted only when one is
// configured (> 0); otherwise the driver's default port applies.
func gormPostgresDSN(config *ORMConfig) string {
    dsn := fmt.Sprintf(
        "host=%s user=%s dbname=%s sslmode=disable password=%s",
        quotePGConnValue(config.Host),
        quotePGConnValue(config.Username),
        quotePGConnValue(config.Database),
        quotePGConnValue(config.Password))
    if config.Port > 0 {
        dsn += fmt.Sprintf(" port=%d", config.Port)
    }
    return dsn
}

// gormMySQLDSN builds a go-sql-driver/mysql DSN with these settings:
//   - TCP to Host:Port;
//   - local time zone;
//   - utf8mb4_general_ci collation;
//   - parseTime enabled.
func gormMySQLDSN(config *ORMConfig) string {
    c := mysql.NewConfig()
    c.Loc = time.Local
    c.User = config.Username
    c.Passwd = config.Password
    c.Net = "tcp"
    c.Addr = fmt.Sprintf("%s:%d", config.Host, config.Port)
    c.DBName = config.Database
    c.Collation = "utf8mb4_general_ci"
    c.ParseTime = true
    return c.FormatDSN()
}

// Close shuts down the underlying *gorm.DB, if one is held.
//
// The handle is discarded even when closing it reports an error, so a later
// Connect always opens a fresh pool. Closing a client that holds no handle is
// a no-op returning nil.
func (client *GORMClient) Close() error {
    client.mu.Lock()
    defer client.mu.Unlock()
    db := client.client
    if db == nil {
        return nil
    }
    closeErr := db.Close()
    client.client = nil // drop the handle unconditionally
    return closeErr
}

// Reconnect closes the current connection and opens a new one, returning the
// result of the new Connect.
//
// Close and Connect each take the lock separately, so the pair is not atomic.
func (client *GORMClient) Reconnect() error {
    // Best effort: Close drops the old handle even when closing it fails, so
    // its error is deliberately discarded.
    _ = client.Close()
    connectErr := client.Connect()
    return connectErr
}
